import logging
import os
from Network.network_utils import pytorch_model
import numpy as np
from Record.logging import Logger
from Record.file_management import create_directory
from torch.utils.tensorboard import SummaryWriter
from collections import deque
from tianshou.data import Batch
import time

def add_key(maxlen, k, add_val, add_dict):
    """Append ``add_val`` to the bounded history stored under ``k`` in ``add_dict``.

    A new ``deque(maxlen=maxlen)`` is created the first time ``k`` is seen;
    an existing history keeps whatever ``maxlen`` it was created with.
    """
    if k in add_dict:
        history = add_dict[k]
    else:
        history = deque(maxlen=maxlen)
        add_dict[k] = history
    history.append(add_val)

def _rate_errors(results, result, name_idx):
    """Return mean |over-prediction| and mean |under-prediction| of ``result`` vs ``results.trace``.

    Both values are averaged over axis 0. ``results`` must contain ``trace``.
    """
    # TODO: messy way of dealing with multiple parents
    if not isinstance(name_idx, list) and name_idx == -1:
        trace = results.trace
    else:
        trace = results.trace[:, name_idx]
        # A 3-D result has an extra key axis. For a single name index, keep only
        # the first key; for a list of indices the lengths are expected to
        # correspond already, so the result is used as-is.
        if len(result.shape) == 3 and not isinstance(name_idx, list):
            result = result[:, 0]
    diff = pytorch_model.unwrap(result) - trace
    # Split the signed error once: positive part = predicted where the trace is
    # off (FP), negative part = missed where the trace is on (FN). np.where also
    # works for 0-d differences, unlike in-place mask assignment.
    over = np.where(diff < 0, 0, diff)
    under = np.where(diff > 0, 0, diff)
    return np.mean(np.abs(over), axis=0), np.mean(np.abs(under), axis=0)


def log_helper(maxlen, record_names, dicts, results, keyvals, name_idx=-1):
    """Recursively record every entry of ``results`` named in ``record_names``.

    Nested ``Batch`` values are descended into. Their key is appended to
    ``keyvals``, so each recorded key is the ``"_"``-joined path to the leaf.

    Args:
        maxlen: history length for every deque created through ``add_key``.
        record_names: object exposing ``mean``, ``complete`` and ``rates``
            collections of leaf key names to record.
        dicts: ``(mean_dict, complete_dict, rates_dict)`` to append into.
        results: a ``Batch`` of results. Any ``Batch`` level that holds a
            ``rates`` key must also hold ``trace``.
        keyvals: list of parent keys leading to ``results``.
        name_idx: ``-1`` to compare rates against the whole ``trace``;
            otherwise an index (or list of indices) selecting trace columns.

    Returns:
        ``(total_seen, total_compute)``: the sum of ``len(omit_flags[0])`` and
        the count of ``omit_flags`` entries found anywhere in the tree.
    """
    mean_dict, complete_dict, rates_dict = dicts[0], dicts[1], dicts[2]
    total_seen, total_compute = 0, 0
    for k in results.keys():
        value = results[k]
        if isinstance(value, Batch):
            seen, compute = log_helper(maxlen, record_names, dicts, value, keyvals + [k], name_idx=name_idx)
            total_seen += seen
            total_compute += compute
            continue
        record_key = "_".join(keyvals + [k])
        if k in record_names.mean:
            add_key(maxlen, record_key, np.mean(pytorch_model.unwrap(value)), mean_dict)
        if k in record_names.complete:
            # flatten everything after the first axis, then average over that axis
            flat = value.reshape(len(value), -1)
            add_key(maxlen, record_key, np.mean(pytorch_model.unwrap(flat), axis=0), complete_dict)
        if k in record_names.rates:
            false_pos, false_neg = _rate_errors(results, value, name_idx)
            add_key(maxlen, "_".join(keyvals + [k + "_FP"]), false_pos, rates_dict)
            add_key(maxlen, "_".join(keyvals + [k + "_FN"]), false_neg, rates_dict)
        if k == "omit_flags":
            total_seen += len(value[0])
            total_compute += 1
    return total_seen, total_compute



class InterLogger(Logger):
    """Record selected result entries and periodically report them.

    Reports go to stdout, TensorBoard (when ``record_graphs`` is non-empty)
    and a W&B-style logger (when ``wdb_logger`` is given).

    Every key in a logged result that appears in ``record_names`` is recorded.
    There are three kinds of record_names:
        mean: mean value over batches and feature indices
        complete: mean over the batch, kept per feature index
        rates: false positive / false negative magnitudes compared with the traces
    """

    def __init__(self, name, name_idx, record_graphs, log_interval, record_names, filename="", denorm=False, wdb_logger=None):
        super().__init__(filename, wdb_logger=wdb_logger)
        self.record_graphs = record_graphs
        self.tensorboard_logger = None
        if len(record_graphs) != 0:
            full_logdir = os.path.join(create_directory(record_graphs + "/logs"))
            self.tensorboard_logger = SummaryWriter(log_dir=full_logdir)
        self.maxlen = 100
        self.log_interval = log_interval
        self.type = name
        self.name_idx = name_idx
        self.denorm = denorm
        self.record_names = record_names
        self.log_count = 0  # number of calls to log(); used as the TensorBoard x-axis
        self.i = 0  # running total of len(omit_flags[0]) over all logged results
        self.per = deque(maxlen=10)  # recent seen/compute ratios, one per log() call
        self.reset()

    def reset(self):
        """Clear all recorded histories and restart the throughput timer."""
        self.complete_dict = dict()
        self.rates_dict = dict()
        self.mean_dict = dict()
        self.last_logged = time.time()

    def log(self, i, result, no_print=False, intermediate_name=""):
        """Record ``result``; every ``log_interval`` iterations, emit the running averages.

        Args:
            i: iteration number. Emission happens when ``i % log_interval == 0``,
                and ``i`` is passed to the W&B logger as ``step``.
            result: batch of outputs, possibly nested (presumably a tianshou
                ``Batch``; it is handed to ``log_helper``).
            no_print: suppress stdout printing; graph logging still happens.
            intermediate_name: suffix appended to the logger name, e.g. to
                distinguish individual loss terms.
        """
        use_name = self.type + intermediate_name
        seen, computed = log_helper(self.maxlen, self.record_names, (self.mean_dict, self.complete_dict, self.rates_dict), result, list(), name_idx=self.name_idx)
        self.i += seen
        # A result without "omit_flags" yields computed == 0; skip it instead of
        # raising ZeroDivisionError.
        if computed > 0:
            self.per.append(seen / computed)
        if i % self.log_interval == 0:
            if not no_print:
                self._print_summary(i, use_name)
            if len(self.record_graphs) != 0 or self.wdb_logger is not None:
                self._write_graphs(use_name, step=i)
            self.last_logged = time.time()
        self.log_count += 1

    def _print_summary(self, i, use_name):
        """Print throughput counters and the running average of every recorded key."""
        per_batch = np.mean(self.per) if len(self.per) > 0 else float("nan")
        print(use_name + f' at {i}, fps {self.log_interval / (time.time() - self.last_logged)} logged {self.log_count} with {self.i} total seen, {per_batch} per batch')
        for k in self.mean_dict:
            print(k + f': {np.mean(self.mean_dict[k])}')
        for k in self.complete_dict:
            print(k + f': {np.mean(self.complete_dict[k], axis=0)}')
        for k in self.rates_dict:
            print(k + f': {np.mean(self.rates_dict[k], axis=0)}')

    def _write_graphs(self, use_name, step):
        """Send running averages to TensorBoard (x = log_count) and/or W&B (step = ``step``)."""
        use_tb = len(self.record_graphs) != 0
        use_wdb = self.wdb_logger is not None
        # scalar values (e.g. losses)
        if use_tb:
            for k in self.mean_dict:
                self.tensorboard_logger.add_scalar(k, np.mean(self.mean_dict[k]), self.log_count)
        if use_wdb:
            self.wdb_logger.log({use_name + "_" + k: np.mean(self.mean_dict[k]) for k in self.mean_dict}, step=step)
        # per-feature averages
        for k in self.complete_dict:
            self._write_vector(use_name, k, np.mean(self.complete_dict[k], axis=0), step, use_tb, use_wdb)
        # FP / FN rates: 2-D entries are first averaged over their last axis
        for k in self.rates_dict:
            history = self.rates_dict[k]
            rdv = history if len(np.shape(history[0])) < 2 else np.mean(history, axis=-1)
            self._write_vector(use_name, k, np.mean(rdv, axis=0), step, use_tb, use_wdb)

    def _write_vector(self, use_name, k, values, step, use_tb, use_wdb):
        """Log each component of ``values`` as a scalar named ``k + str(index)``."""
        # np.atleast_1d lets a scalar average (e.g. a rate over 1-D results) be logged too.
        values = np.atleast_1d(values)
        filled_dict = dict()
        for idx, value in enumerate(values):
            if use_tb:
                self.tensorboard_logger.add_scalar(k + str(idx), value, self.log_count)
            filled_dict[use_name + "_" + k + str(idx)] = value
        if use_wdb:
            self.wdb_logger.log(filled_dict, step=step)
